import copy
import logging
from collections import deque, defaultdict

import jax
import clip
import torch
import jax.numpy as jnp
from PIL import Image
import numpy as np
from tqdm.auto import tqdm, trange

from .vl_reward import (
    device,
    get_torch_clip_reward,
    get_torch_clip_adapter_reward,
    get_torch_ts2net_reward,
    get_torch_mugen_reward
)


def _build_env_action(action, env, high, low, use_rotation):
    """Convert a normalized policy action into the action passed to ``env.step``.

    Each used action component ``a`` is mapped from ``[-1, 1]`` to
    ``[low, high]`` with ``(a + 1) / 2 * (high - low) + low``.

    Args:
        action: Policy action after ``transform_action_fn``. ``action[:3]`` is
            the position part, ``action[3:7]`` the rotation part (read only
            when ``use_rotation``), and ``action[-1]`` the gripper.
        env: Environment; this reads ``env.config.absolute_mode``,
            ``env._init_pose`` and the tip pose of ``env._task._task.robot.arm``.
        high, low: Per-dimension bounds indexed with the same layout as
            ``action``.
        use_rotation: Whether to use the rotation part of ``action``.

    Returns:
        ``jnp.hstack`` of the target position, quaternion and gripper value.
    """
    def denormalize(x, lo, hi):
        return (x + 1) / 2 * (hi - lo) + lo

    absolute_mode = env.config.absolute_mode

    # Read the tip pose once. Every branch below needs it: the z clamp, or the
    # current orientation / absolute target. Previously the absolute-mode
    # rotation branch used ``curr_pose`` without ever assigning it (NameError).
    tip_pose = env._task._task.robot.arm.get_tip().get_pose()
    curr_pose = tip_pose[:3]
    curr_quat = tip_pose[3:]

    # Copy into a writable NumPy array so the in-place z clamp below also
    # works when ``action`` is an (immutable) JAX array.
    pose = np.array(denormalize(action[:3], low[:3], high[:3]))

    if not absolute_mode:
        # Manual handling of overflow in z axis: drop the z delta if it would
        # bring the tip to or above its initial height.
        if curr_pose[2] + pose[2] >= env._init_pose[2]:
            pose[2] = 0.0

    if use_rotation:
        d_quat = denormalize(action[3:7], low[3:7], high[3:7])
        if not absolute_mode:
            target_pose = pose
            quat = d_quat
        else:
            # NOTE(review): this treats the action as a delta on top of the
            # current tip pose, whereas the non-rotation branch below uses the
            # position as-is in absolute mode. Confirm this is intended.
            target_pose = curr_pose + pose
            quat = curr_quat + d_quat
        quat = quat / jnp.linalg.norm(quat)
    else:
        target_pose = pose
        if absolute_mode:
            # Keep the current orientation.
            quat = curr_quat
        else:
            quat = jnp.array([0.0, 0.0, 0.0, 1.0])

    gripper = denormalize(action[-1], low[-1], high[-1])
    return jnp.hstack([target_pose, quat, gripper])


def batch_rollout(
    rng,
    data_aug_rng,
    env,
    policy_fn,
    transform_obs_fn,
    transform_action_fn,
    episode_length=2500,
    log_interval=None,
    window_size=0,
    video_window_size=12,
    num_episodes=1,
    num_actions=8,
    return_to_go=100.0,
    scale=100.0,
    clip_model=None,
    vl_type="clip",
    pos_text=None,
    neg_text=None,
    reward_mean=0.0,
    reward_std=1.0,
    use_normalize=False,
    high=0.0,
    low=0.0,
    use_rotation=False
):
    """Roll out ``policy_fn`` in ``env`` and report average return and success rate.

    The policy is conditioned on a per-camera return-to-go (``rtg``). When a
    vision-language (VL) model is supplied via ``clip_model``, ``rtg[key]`` for
    each camera ``key`` is decreased every step by that model's reward. The
    reward is computed on the current frame (``vl_type == "clip"``) or on the
    buffered recent frames (``"ts2net"`` / ``"mugen"``).

    Args:
        rng: Passed unchanged as ``rngs`` to ``policy_fn`` on every step.
        data_aug_rng: RNG threaded through ``transform_obs_fn``.
        env: Environment. It must expose ``reset()``, ``step(action)``
            returning ``(obs, reward, done, info)`` with ``info["success"]``,
            and ``config.image_key`` (camera names joined by ``", "``).
        policy_fn: Called as ``policy_fn(inputs=..., rngs=rng)``. Element
            ``[0]`` of the host-fetched result is used as the action.
        transform_obs_fn: Optional ``(image, rng) -> (image, rng)``, applied to
            a copy of every camera image before it is fed to the policy.
        transform_action_fn: Applied to the raw policy action.
        episode_length: Maximum number of steps per episode.
        log_interval: If truthy, log every ``log_interval`` steps.
        window_size: Number of trailing timesteps fed to the policy. ``0``
            disables this trim, because ``x[:, -0:]`` keeps the whole array.
        video_window_size: Frames kept per camera for the video models. The
            stored step history is also trimmed to this length.
        num_episodes: Number of episodes to roll out.
        num_actions: Length of the zero placeholder action for the current step.
        return_to_go: Initial return-to-go, divided by ``scale``.
        scale: Divisor applied to the initial return-to-go and to VL rewards.
        clip_model: ``None``, or a pair whose ``[0]`` is the VL model (``None``
            disables VL rewards) and whose ``[1]`` is used as the per-frame
            preprocessor for ``"ts2net"``.
        vl_type: One of ``"clip"``, ``"ts2net"``, ``"mugen"``.
        pos_text, neg_text: Texts scored by the VL model.
        reward_mean, reward_std: Per-camera mappings used when ``use_normalize``.
        use_normalize: Standardize VL rewards before subtracting them.
        high, low: Per-dimension action bounds. They must be indexable like the
            action; the scalar defaults cannot be sliced.
        use_rotation: Whether the policy's rotation output is used.

    Returns:
        ``(metric, info)``:

        - ``metric["return"]`` is the environment return averaged over
          episodes.
        - ``metric["success"]`` is the percentage of episodes in which
          ``info["success"]`` was reported on at least one step.
        - ``info`` is the last ``info`` returned by ``env.step`` (``{}`` if no
          step ran).

    Raises:
        ValueError: If a VL model is given and ``vl_type`` is unsupported.
    """
    concat_fn = lambda x, y: jnp.concatenate([x, y], axis=1)
    video_trim_fn = lambda x: x[:, -video_window_size:, ...]
    trim_fn = lambda x: x[:, -window_size:, ...]
    batch_fn = lambda x: x[None, None, ...]

    # ``clip_model`` defaults to None, so check it before indexing into it.
    use_vl = clip_model is not None and clip_model[0] is not None

    def prepare_input(all_inputs, obs, rtg):
        """Return the policy input.

        The input is the stored history plus the current step (with a zero
        placeholder action), trimmed to ``window_size`` steps.
        """
        inputs = {**obs, "action": jnp.zeros(num_actions), "rtg": rtg}
        inputs = jax.tree_util.tree_map(batch_fn, inputs)
        if all_inputs:
            # jnp.concatenate returns new arrays and never mutates
            # ``all_inputs``, so no defensive deep copy is needed.
            inputs = jax.tree_util.tree_map(concat_fn, all_inputs, inputs)
            inputs = jax.tree_util.tree_map(trim_fn, inputs)
        return inputs

    def update_input(all_inputs, obs, action, rtg):
        """Append the executed step to the history.

        The result keeps the last ``video_window_size`` steps.
        """
        inputs = jax.tree_util.tree_map(batch_fn, {**obs, "action": action, "rtg": rtg})
        if not all_inputs:
            return inputs
        all_inputs = jax.tree_util.tree_map(concat_fn, all_inputs, inputs)
        return jax.tree_util.tree_map(video_trim_fn, all_inputs)

    def update_preprocessed_video(video_stack):
        """Stack one camera's buffered frames into a tensor on ``device``."""
        if vl_type == "ts2net":
            return torch.stack(list(video_stack)).to(device)
        if vl_type == "mugen":
            return torch.from_numpy(np.asarray(video_stack)).unsqueeze(0).to(device)
        raise ValueError(f"Unsupported vl_type: {vl_type!r}")

    def push_frames(video_stack, obs, fill):
        """Add each camera's current frame to that camera's frame buffer.

        When ``fill`` is true (first step of an episode), the frame is repeated
        to fill the whole window.
        """
        for key, image in obs['image'].items():
            if vl_type == "ts2net":
                frame = clip_model[1](np.asarray(image))
            elif vl_type == "mugen":
                frame = np.asarray(image)
            else:
                # "clip" scores single frames and needs no buffer.
                continue
            if fill:
                video_stack[key].extend([frame] * video_window_size)
            else:
                video_stack[key].append(frame)

    # Encode the positive/negative texts once, up front.
    if use_vl:
        if vl_type == "ts2net":
            def get_input(text):
                input_ids = clip.tokenize(text).to(device)
                # NOTE(review): len(input_ids) is the size of the first
                # dimension of the tokenized batch, not the token count.
                # Confirm this mask shape is what get_sequence_output expects.
                input_mask = [1] * len(input_ids)
                segment_ids = [0] * len(input_ids)
                input_mask, segment_ids = map(lambda x: torch.LongTensor(x).to(device), [input_mask, segment_ids])
                return input_ids, input_mask, segment_ids

            pos_input_ids, pos_input_mask, pos_segment_ids = get_input(pos_text)
            neg_input_ids, neg_input_mask, neg_segment_ids = get_input(neg_text)
            model, _ = clip_model
            pos_seq_output = model.get_sequence_output(pos_input_ids, pos_segment_ids, pos_input_mask)
            neg_seq_output = model.get_sequence_output(neg_input_ids, neg_segment_ids, neg_input_mask)
        elif vl_type == "mugen":
            model, _ = clip_model
            pos_seq_output = model.get_text_embedding({"text": [pos_text]})
            neg_seq_output = model.get_text_embedding({"text": [neg_text]})
        elif vl_type != "clip":
            raise ValueError(f"Unsupported vl_type: {vl_type!r}")

    reward = jnp.zeros(1, dtype=jnp.float32)
    success = jnp.zeros(1, dtype=jnp.float32)
    # Returned as-is if no step is ever taken.
    info = {}

    for episode in trange(num_episodes, desc="rollout", ncols=0):
        rtg = {
            key: jnp.full(1, return_to_go / scale, dtype=jnp.float32)
            for key in env.config.image_key.split(", ")
        }
        all_inputs = {}
        done = jnp.zeros(1, dtype=jnp.int32)
        # Count success at most once per episode, so the metric is a
        # percentage of episodes rather than of steps.
        episode_success = jnp.zeros(1, dtype=jnp.float32)
        video_stack = defaultdict(lambda: deque([], maxlen=video_window_size))

        for t in trange(episode_length, desc=f"episode {episode}", ncols=0, leave=False):
            done_prev = done

            obs = env.reset() if t == 0 else next_obs
            if use_vl:
                push_frames(video_stack, obs, fill=(t == 0))

            if transform_obs_fn is not None:
                # Deep copy so augmentation never alters ``obs``, whose raw
                # images are still scored by the VL model below.
                input_obs = copy.deepcopy(obs)
                for key, val in input_obs['image'].items():
                    input_obs['image'][key], data_aug_rng = transform_obs_fn(val, data_aug_rng)
            else:
                input_obs = obs

            inputs = prepare_input(all_inputs, input_obs, rtg)
            action = jax.device_get(policy_fn(inputs=inputs, rngs=rng))[0]
            action = transform_action_fn(action)
            all_inputs = update_input(all_inputs, input_obs, action, rtg)

            env_action = _build_env_action(action, env, high, low, use_rotation)
            next_obs, _reward, done, info = env.step(env_action)

            reward = reward + _reward * (1 - done_prev)

            # Compute the implicit (VL) reward and lower each camera's rtg by it.
            if use_vl:
                for key, image in obs['image'].items():
                    if vl_type == "clip":
                        clip_reward = get_torch_clip_reward(clip_model, image, pos_text, neg_text)
                    elif vl_type == "ts2net":
                        clip_reward = get_torch_ts2net_reward(
                            clip_model,
                            update_preprocessed_video(video_stack[key]),
                            pos_seq_output,
                            pos_input_mask,
                            neg_seq_output,
                            neg_input_mask,
                        )
                    elif vl_type == "mugen":
                        clip_reward = get_torch_mugen_reward(
                            clip_model,
                            update_preprocessed_video(video_stack[key]),
                            pos_seq_output,
                            neg_seq_output,
                        )
                    else:
                        raise ValueError(f"Unsupported vl_type: {vl_type!r}")

                    if use_normalize:
                        rtg[key] -= (clip_reward - reward_mean[key]) / (scale * reward_std[key])
                    else:
                        rtg[key] -= clip_reward / scale

            done = jnp.logical_or(done, done_prev).astype(jnp.int32)
            episode_success = jnp.maximum(
                episode_success, jnp.array([info["success"]], dtype=jnp.float32)
            )

            if log_interval and t % log_interval == 0:
                logging.info("step: %d done: %s reward: %s", t, done, reward)

            if jnp.all(done):
                break

        success = success + episode_success

    metric = {
        "return": reward.astype(jnp.float32) / num_episodes,
        "success": success.astype(jnp.float32) / num_episodes * 100,
    }
    return metric, info
